from cmath import inf
from heapq import heappop, heappush
from copy import deepcopy
from math import ceil
from random import choice, random, sample
from typing import Tuple
from agh_data import AGHData
from agh_obj import *
import toolbox

# Penalty added to a chromosome's cost for every infeasible assignment
# (see GA.get_fitness).
MAX = 999_999_999_999
# Multiplier turning the minute-based window width into the unit used by the
# flight time stamps (presumably seconds -- confirm against AGHData).
MINUTE = 60

# Cost per time unit a vehicle arrives before the flight's arrival time.
COST_EARLY = 10
# Cost per time unit a vehicle arrives after the flight's arrival time.
COST_LATE = 20

# Width of every vehicle arrival window, expressed in MINUTE units.
TIME_WIN_WIDTH = 10
# Not referenced by the code in this module section.
TIME_STEP = 30

# Genetic-algorithm defaults, copied into each GA instance on construction.
POP_NUM = 300
CROSS_RATE = 0.9
MUTATION_RATE = 0.01
ITERA_TIMES = 500
# Number of leading chromosomes GA.mutation never touches.
PRESERVED_NUM = 20
# Number of tournament rounds run by GA.select.
SELECT_NUM = 20

class Operation():
    """A single vehicle movement of a dispatch plan.

    Instances order themselves by ``time`` so they can be kept in a heap and
    popped in chronological order.
    """

    def __init__(self, time, vehicle_type, vehicle_index, gate_index, flight_index, distance, source_index, is_soucre_a_park) -> None:
        """Store the fields of one movement.

        Args:
            time: moment the vehicle sets off.
            vehicle_type: type of the dispatched vehicle.
            vehicle_index: index of the vehicle within its type.
            gate_index: destination gate.
            flight_index: flight being served.
            distance: length of the trip from the source to the gate.
            source_index: index of the place the vehicle leaves from.
            is_soucre_a_park: True when ``source_index`` refers to a park,
                False when it refers to a gate.
        """
        # Where the vehicle comes from.
        self.source_index: int = source_index
        self.is_source_a_park: bool = is_soucre_a_park
        # Which vehicle moves, and when.
        self.vehicle_type: VehicleType = vehicle_type
        self.vehicle_index: int = vehicle_index
        self.time: int = time
        # Where it goes and what for.
        self.gate_index: int = gate_index
        self.flight_index: int = flight_index
        self.distance: float = distance

    def __lt__(self, other) -> bool:
        """Earlier operations sort first."""
        return other.time > self.time

    def __str__(self) -> str:
        """Render as "<time> <vehicle type> <gate> <flight> <vehicle index>"."""
        clock_text = Clock.get_time_str(self.time)
        return f"{clock_text} {self.vehicle_type} {self.gate_index} {self.flight_index} {self.vehicle_index}"

class VehicleInfo():
    """Mutable bookkeeping for one vehicle while a plan is being simulated."""

    def __init__(self, capcity:int, park_index:int, vtype:VehicleType) -> None:
        """Start the vehicle idle at its park with the given remaining capacity."""
        self.type:VehicleType = vtype
        self.park_index:int = park_index
        self.capcity:int = capcity
        # Time from which the vehicle is free for its next job.
        self.complete_time:int = 0
        # Current location: the park when is_source_a_park is True,
        # otherwise the gate stored in gate_index (-1 until a gate is served).
        self.is_source_a_park:bool = True
        self.gate_index:int = -1

    def consume_capcity(self, gate_index:int, complete_time, back_dis) -> float:
        """Record one finished service and return the cost of any trip back to the park.

        Args:
            gate_index: gate where the service took place.
            complete_time: time at which the service ended.
            back_dis: distance from that gate back to the vehicle's park.

        Returns:
            ``back_dis`` times the type's unit cost when the capacity reached
            zero and the vehicle returns to its park, otherwise 0.
        """
        return_cost = 0
        self.capcity -= 1
        if self.capcity != 0:
            # Capacity left: the vehicle waits at the gate it just served.
            self.is_source_a_park = False
            self.gate_index = gate_index
            self.complete_time = complete_time
        else:
            # Exhausted: drive back, recover, and become available at the park.
            spec = Vehicle.property_of_type[self.type]
            unavailable_for = spec.recover_time_cost + ceil(back_dis/spec.speed)
            self.complete_time = complete_time + unavailable_for
            self.capcity = spec.capcity
            self.is_source_a_park = True
            return_cost += back_dis * spec.cost
        return return_cost
        
class DispatchMethod():
    """Common base of the dispatch planners.

    Holds the planner's display name and the per-vehicle-type arrival windows,
    stored as offsets relative to a flight's arrival time. Both window lists
    are indexed by ``VehicleType.value`` and stay ``None`` until
    ``construct_arrival_time_windows`` is called.
    """

    def __init__(self) -> None:
        self.name:str = ""
        self.far_gate_arrival_time_windows:List[Tuple[int, int]] = None
        self.near_gate_arrival_time_windows:List[Tuple[int, int]] = None

    def generate_dispath_plan(self, agh_data:AGHData) -> Tuple[bool, List[Operation], str]:
        """Build a plan for ``agh_data``; subclasses override this.

        Subclasses return ``(success, operations, error_message)``. This base
        implementation does nothing and returns ``None``.
        """
        pass

    def generate_vehicle_infos(self, vehicle_data:List[List[Vehicle]]) -> List[List[VehicleInfo]]:
        """Create a fresh ``VehicleInfo`` for every vehicle, grouped by type.

        The outer list follows the iteration order of ``VehicleType``; each
        vehicle starts with the full capacity of its type at its own park.
        """
        return [[VehicleInfo(Vehicle.property_of_type[vtype].capcity, v.park_index, vtype)\
                             for v in vehicle_data[vtype.value]] for vtype in VehicleType]

    def _build_arrival_time_windows(self, skip_shuttle_bus:bool) -> List[Tuple[int, int]]:
        """Compute the (earliest, latest) arrival offsets of every vehicle type.

        A window opens half a window width before the flight arrives, or when
        the latest prerequisite service can be finished, whichever is later.
        With ``skip_shuttle_bus`` the shuttle bus gets no window (its slot
        keeps the ``(-1, -1)`` placeholder) and is ignored as a prerequisite.
        """
        window_width = TIME_WIN_WIDTH * MINUTE
        # Integer arithmetic on purpose: true division would turn every bound
        # into a float, contradicting the Tuple[int, int] contract and breaking
        # integer-only consumers such as random.randint on Python 3.12+.
        default_start = -(window_width // 2)

        windows = [(-1, -1)] * len(VEHICLE_NAMES)
        for vtype in VehicleType:
            if skip_shuttle_bus and vtype == VehicleType.ShuttleBus:
                continue
            start_time = default_start
            # NOTE(review): a prerequisite's window is read from the list being
            # built, so this relies on VehicleType iterating prerequisites
            # before their dependants; otherwise the (-1, -1) placeholder is
            # used. Confirm against the enum definition.
            for pre_type in VehicleType.get_pre_service_car_type(vtype):
                if skip_shuttle_bus and pre_type == VehicleType.ShuttleBus:
                    continue
                pre_service_end = windows[pre_type.value][0] \
                    + Vehicle.property_of_type[pre_type].service_time_cost
                start_time = max(start_time, pre_service_end)
            windows[vtype.value] = (start_time, start_time + window_width)
        return windows

    def construct_arrival_time_windows(self):
        """Fill in the far-gate and near-gate arrival windows.

        Far gates use every vehicle type; near gates leave the shuttle bus out.
        """
        self.far_gate_arrival_time_windows = self._build_arrival_time_windows(False)
        self.near_gate_arrival_time_windows = self._build_arrival_time_windows(True)

    def get_arrival_time_border(self, vtype:VehicleType, flight_arrival_time:int, is_near_gate:bool) -> Tuple[int, int]:
        """Return the absolute (earliest, latest) arrival times of ``vtype`` for one flight.

        ``construct_arrival_time_windows`` must have been called beforehand.
        """
        arrival_time_windows = self.near_gate_arrival_time_windows if is_near_gate else self.far_gate_arrival_time_windows
        earliest_offset, latest_offset = arrival_time_windows[vtype.value]
        return earliest_offset + flight_arrival_time, latest_offset + flight_arrival_time

    def get_pre_service_end_time(self, default_value:int, complete_times:List[int], vtype:VehicleType, is_near_gate:bool) -> int:
        """Return the time at which all prerequisite services of ``vtype`` are done.

        Args:
            default_value: lower bound used when no prerequisite ends later.
            complete_times: service end times of the current flight, indexed
                by ``VehicleType.value``.
            vtype: vehicle type about to start its service.
            is_near_gate: True to ignore the shuttle bus as a prerequisite.
        """
        pre_service_end_time = default_value
        for pre_type in VehicleType.get_pre_service_car_type(vtype):
            if is_near_gate and pre_type == VehicleType.ShuttleBus:
                continue
            pre_service_end_time = max(pre_service_end_time, complete_times[pre_type.value])
        return pre_service_end_time

class FCFS(DispatchMethod):
    """First-come-first-served planner.

    Flights are handled in the order of ``agh_data.flight_data``. For each
    vehicle type a vehicle already waiting at a gate is preferred; otherwise
    the nearest available parked vehicle is sent.
    """

    def __init__(self) -> None:
        super().__init__()
        self.name:str = "先来先服务"

    def generate_dispath_plan(self, agh_data:AGHData) -> Tuple[bool, List[Operation], str]:
        """Build a plan by serving the flights one after another.

        Returns:
            ``(True, operations sorted by launch time, "")`` on success,
            ``(False, [], message)`` when some flight cannot get a vehicle of
            a required type, and ``(False, [], "")`` when the user cancels.
        """
        vehicle_infos = self.generate_vehicle_infos(agh_data.vehicle_data)
        plan = []
        self.construct_arrival_time_windows()

        def allocate_vehicle_for_flight(gate_index, flight_index, flight_arrival_time) -> str:
            """Assign one vehicle of every required type to a flight.

            Pushes the resulting operations onto ``plan`` and returns an empty
            string, or returns an error message when a type has no usable
            vehicle.
            """
            complete_times = [0] * len(VEHICLE_NAMES)
            is_near_gate = agh_data.gate_data[gate_index].is_near_gate

            for vtype in VehicleType:
                # Near gates need no shuttle bus.
                if is_near_gate and vtype == VehicleType.ShuttleBus:
                    continue

                speed = Vehicle.property_of_type[vtype].speed
                pre_service_end_time = \
                    self.get_pre_service_end_time(flight_arrival_time, complete_times, vtype, is_near_gate)
                _, latest_arrival_time = self.get_arrival_time_border(vtype, flight_arrival_time, is_near_gate)

                # First choice: among the vehicles waiting at a gate that can
                # arrive in time, the one that has been free the longest.
                chosed_v_from_gate = -1
                earliest_complete_time = flight_arrival_time
                for index, info in enumerate(vehicle_infos[vtype.value]):
                    if info.is_source_a_park:
                        continue
                    dis = agh_data.get_distance_between_gates(info.gate_index, gate_index)
                    if dis == -1:
                        # -1 is treated as "no route between these gates".
                        continue
                    vehicle_arrival_time = info.complete_time + ceil(dis/speed)
                    if vehicle_arrival_time <= latest_arrival_time and info.complete_time < earliest_complete_time:
                        chosed_v_from_gate = index
                        earliest_complete_time = info.complete_time

                if chosed_v_from_gate != -1:
                    chosed_v_index = chosed_v_from_gate
                else:
                    # Fallback: the nearest parked vehicle that can leave late
                    # enough to arrive exactly when the prerequisites are done.
                    chosed_v_from_park = -1
                    mindis = inf
                    for index, info in enumerate(vehicle_infos[vtype.value]):
                        if not info.is_source_a_park:
                            continue
                        dis = agh_data.get_distance_park_to_gate(info.park_index, gate_index)
                        vehicle_lanuch_time = pre_service_end_time - ceil(dis/speed)
                        if vehicle_lanuch_time >= info.complete_time and dis < mindis:
                            chosed_v_from_park = index
                            mindis = dis

                    if chosed_v_from_park == -1:
                        return f"没有足够的{VEHICLE_NAMES[vtype.value]}"
                    chosed_v_index = chosed_v_from_park

                vehicle_info = vehicle_infos[vtype.value][chosed_v_index]

                if vehicle_info.is_source_a_park:
                    dis = agh_data.get_distance_park_to_gate(vehicle_info.park_index, gate_index)
                    vehicle_lanuch_time = pre_service_end_time - ceil(dis/speed)
                    vehicle_arrival_time = pre_service_end_time
                    source = vehicle_info.park_index
                else:
                    dis = agh_data.get_distance_between_gates(vehicle_info.gate_index, gate_index)
                    vehicle_lanuch_time = vehicle_info.complete_time
                    vehicle_arrival_time = vehicle_info.complete_time + ceil(dis/speed)
                    source = vehicle_info.gate_index

                heappush(plan, Operation(vehicle_lanuch_time, vtype, chosed_v_index, gate_index, flight_index, \
                    dis, source, vehicle_info.is_source_a_park))

                # The service starts once both the vehicle and the
                # prerequisite services are ready.
                service_start_time = max(pre_service_end_time, vehicle_arrival_time)
                complete_time = service_start_time + Vehicle.property_of_type[vtype].service_time_cost
                complete_times[vtype.value] = complete_time

                back_dis = agh_data.get_distance_park_to_gate(vehicle_info.park_index, gate_index)
                vehicle_info.consume_capcity(gate_index, complete_time, back_dis)

            return ""

        flight_count = len(agh_data.flight_data)
        progress_dialog = toolbox.open_progress_dialog(f"生成{self.name}方案中", flight_count)

        for index, flight in enumerate(agh_data.flight_data):
            error_message = allocate_vehicle_for_flight(flight.gate_index, index, flight.arrival_time_stamp)
            if error_message:
                # Report the failure instead of returning a partial plan.
                # NOTE(review): driving the dialog to its maximum presumably
                # dismisses it -- confirm against toolbox.open_progress_dialog.
                progress_dialog.setValue(flight_count)
                return False, [], error_message
            progress_dialog.setValue(index + 1)
            if progress_dialog.wasCanceled():
                return False, [], ""

        # Drain the heap to obtain the operations in launch-time order.
        solution = []
        while plan:
            solution.append(heappop(plan))
        return True, solution, ""

class GA(DispatchMethod):
    """Genetic-algorithm planner.

    A chromosome's ``data`` holds one entry per flight, in the order of
    ``agh_data.flight_data``. Each entry is a list indexed by
    ``VehicleType.value`` whose items are ``(vehicle_index,
    vehicle_arrival_time)`` genes; the shuttle-bus slot of a near-gate flight
    is ``(-1, -1)``. Fitness is the reciprocal of the plan cost, with a
    ``MAX`` penalty per infeasible assignment.
    """

    # Upper bound on the random draws spent on one gene while seeding the
    # population; keeps generate_chromosomes from spinning forever when no
    # feasible vehicle exists.
    MAX_SAMPLING_ATTEMPTS = 1000

    class Chromosome():
        """Candidate plan plus its cached fitness (-1 until evaluated)."""

        def __init__(self, data) -> None:
            self.data = data
            self.fitness:float = -1

    def __init__(self) -> None:
        super().__init__()
        self.name = "遗传算法"
        self.pop_num = POP_NUM
        self.cross_rate = CROSS_RATE
        self.mutation_rate = MUTATION_RATE
        self.itera_times = ITERA_TIMES

    @staticmethod
    def _fitness_from_cost(cost) -> float:
        """Map a plan cost to a fitness, treating a zero cost as the best possible."""
        return 1 / cost if cost > 0 else inf

    def generate_dispath_plan(self, agh_data:AGHData) -> Tuple[bool, List[Operation], str]:
        """Evolve a population for ``itera_times`` generations and decode the best plan.

        Returns:
            ``(True, operations, "")`` on success, ``(False, [], "未找到可行解")``
            when the best chromosome is infeasible, and ``(False, [], "")``
            when the user cancels.
        """
        def fitness_of(chromosome):
            return chromosome.fitness

        self.construct_arrival_time_windows()
        chromosomes = self.generate_chromosomes(self.pop_num, agh_data)
        # Keep a private copy: population members are mutated in place later.
        best_chromosome = deepcopy(max(chromosomes, key=fitness_of))

        progress_dialog = toolbox.open_progress_dialog(f"生成{self.name}方案中", self.itera_times)

        best_fitness = []
        i = 0
        while i < self.itera_times:
            selected_chromosomes = self.select(chromosomes)
            crossed_chromosomes = self.crossover(selected_chromosomes, agh_data)

            # The children replace the weakest members of the population.
            chromosomes.sort(reverse=True, key=fitness_of)
            survivor_num = max(0, self.pop_num - len(crossed_chromosomes))
            chromosomes = chromosomes[:survivor_num] + crossed_chromosomes
            chromosomes = self.mutation(chromosomes, agh_data)

            # Elitism: compare the incumbent with the population only, and
            # re-insert it whenever the population lost it.
            generation_best = max(chromosomes, key=fitness_of)
            if generation_best.fitness < best_chromosome.fitness:
                chromosomes.append(deepcopy(best_chromosome))
            elif generation_best.fitness > best_chromosome.fitness:
                best_chromosome = deepcopy(generation_best)

            best_fitness.append(best_chromosome.fitness)

            i += 1
            progress_dialog.setValue(i)
            if progress_dialog.wasCanceled():
                return False, [], ""

        toolbox.open_chart_dialog([best_fitness], ["test"], "fitness", 'itertimes', 'fitness')

        plan = self.translate(best_chromosome, agh_data)
        if len(plan) == 0:
            return False, [], "未找到可行解"

        return True, plan, ""

    def generate_chromosomes(self, pop_num:int, agh_data:AGHData) -> List[Chromosome]:
        """Create ``pop_num`` random chromosomes with their fitness evaluated.

        Every gene is drawn at random until the chosen vehicle can reach the
        gate at the drawn time, for at most ``MAX_SAMPLING_ATTEMPTS`` draws.
        If no feasible draw is found the last one is kept; ``get_fitness``
        then penalises it.
        """
        chromosomes = []
        vehicle_nums = [len(v) for v in agh_data.vehicle_data]
        for _ in range(pop_num):
            data = []
            vehicle_infos = self.generate_vehicle_infos(agh_data.vehicle_data)
            for flight in agh_data.flight_data:
                flight_arrival_time = flight.arrival_time_stamp
                gate_index = flight.gate_index
                complete_times = [0] * len(VEHICLE_NAMES)
                vehicles = []
                is_near_gate = agh_data.gate_data[gate_index].is_near_gate
                for vtype in VehicleType:
                    if is_near_gate and vtype == VehicleType.ShuttleBus:
                        # Placeholder keeps the gene list indexable by type.
                        vehicles.append((-1, -1))
                        continue

                    e, l = self.get_arrival_time_border(vtype, flight_arrival_time, is_near_gate)
                    speed = Vehicle.property_of_type[vtype].speed

                    for _attempt in range(self.MAX_SAMPLING_ATTEMPTS):
                        vehicle_index = randint(0, vehicle_nums[vtype.value] - 1)
                        vehicle_arrival_time = randint(e, l)
                        vehicle_info = vehicle_infos[vtype.value][vehicle_index]

                        if vehicle_info.is_source_a_park:
                            dis = agh_data.get_distance_park_to_gate(vehicle_info.park_index, gate_index)
                        else:
                            dis = agh_data.get_distance_between_gates(vehicle_info.gate_index, gate_index)

                        if dis == -1:
                            # -1 is treated as "no route"; draw again.
                            continue

                        vehicle_lanuch_time = vehicle_arrival_time - ceil(dis/speed)
                        if vehicle_lanuch_time >= vehicle_info.complete_time:
                            break

                    # Advance the simulated vehicle state exactly as
                    # get_fitness and translate replay it.
                    pre_service_end_time = \
                        self.get_pre_service_end_time(max(flight_arrival_time, vehicle_arrival_time), complete_times, vtype, is_near_gate)

                    complete_time = pre_service_end_time + Vehicle.property_of_type[vtype].service_time_cost
                    complete_times[vtype.value] = complete_time

                    back_dis = agh_data.get_distance_park_to_gate(vehicle_info.park_index, gate_index)
                    vehicle_info.consume_capcity(gate_index, complete_time, back_dis)

                    vehicles.append((vehicle_index, vehicle_arrival_time))
                data.append(vehicles)

            new_chromosome = GA.Chromosome(data)
            # One scoring path for seeded, crossed and mutated chromosomes.
            new_chromosome.fitness = self.get_fitness(new_chromosome, agh_data)
            chromosomes.append(new_chromosome)
        return chromosomes

    def get_fitness(self, chromosome:Chromosome, agh_data:AGHData) -> float:
        """Replay ``chromosome`` and return the reciprocal of its total cost.

        The cost sums travel distance times unit cost (including trips back to
        the park), early/late arrival penalties, and ``MAX`` for every
        unreachable gate or vehicle that would have to leave before it is free.
        """
        vehicle_infos = self.generate_vehicle_infos(agh_data.vehicle_data)
        cost = 0

        for flight_index, op_infos in enumerate(chromosome.data):
            flight = agh_data.flight_data[flight_index]
            flight_arrival_time = flight.arrival_time_stamp
            gate_index = flight.gate_index
            complete_times = [0] * len(VEHICLE_NAMES)
            is_near_gate = agh_data.gate_data[gate_index].is_near_gate

            for vtype in VehicleType:
                if is_near_gate and vtype == VehicleType.ShuttleBus:
                    continue

                vehicle_index, vehicle_arrival_time = op_infos[vtype.value]
                vehicle_info = vehicle_infos[vtype.value][vehicle_index]
                spec = Vehicle.property_of_type[vtype]

                if vehicle_info.is_source_a_park:
                    dis = agh_data.get_distance_park_to_gate(vehicle_info.park_index, gate_index)
                else:
                    dis = agh_data.get_distance_between_gates(vehicle_info.gate_index, gate_index)

                if dis == -1:
                    # Unreachable gate: penalise and keep replaying with a
                    # zero-length trip.
                    cost += MAX
                    dis = 0

                vehicle_lanuch_time = vehicle_arrival_time - ceil(dis/spec.speed)
                if vehicle_lanuch_time < vehicle_info.complete_time:
                    # The vehicle would have to leave before it is free.
                    cost += MAX

                if vehicle_arrival_time > flight_arrival_time:
                    cost += COST_LATE * (vehicle_arrival_time - flight_arrival_time)
                else:
                    cost += COST_EARLY * (flight_arrival_time - vehicle_arrival_time)

                pre_service_end_time = \
                    self.get_pre_service_end_time(max(flight_arrival_time, vehicle_arrival_time), complete_times, vtype, is_near_gate)

                complete_time = pre_service_end_time + spec.service_time_cost
                complete_times[vtype.value] = complete_time

                back_dis = agh_data.get_distance_park_to_gate(vehicle_info.park_index, gate_index)
                cost += vehicle_info.consume_capcity(gate_index, complete_time, back_dis)
                cost += dis * spec.cost

        return self._fitness_from_cost(cost)

    def select(self, chromosomes:List[Chromosome]) -> List[Chromosome]:
        """Run ``SELECT_NUM`` tournaments over half the population size each.

        Returns the winners; one chromosome may win several tournaments.
        """
        k = int(self.pop_num / 2)
        return [max(sample(chromosomes, k), key=lambda c:c.fitness) for _ in range(SELECT_NUM)]

    def crossover(self, chromosomes:List[Chromosome], agh_data:AGHData) -> List[Chromosome]:
        """Breed children by exchanging the genes of one random flight.

        Each parent mates, with probability ``cross_rate``, with a random
        partner from ``chromosomes``. Both parents are copied, the copies swap
        their gene list for one flight, and both children are returned with
        their fitness evaluated. The parents are left untouched.
        """
        crossed_chromosomes = []
        flight_num = len(agh_data.flight_data)
        for father in chromosomes:
            if random() >= self.cross_rate:
                continue
            mother = choice(chromosomes)
            child_1, child_2 = deepcopy(father), deepcopy(mother)

            pos = randint(0, flight_num - 1)
            # Exchange the genes; the right-hand side must name the other child.
            child_1.data[pos], child_2.data[pos] = child_2.data[pos], child_1.data[pos]

            for child in (child_1, child_2):
                child.fitness = self.get_fitness(child, agh_data)
                crossed_chromosomes.append(child)

        return crossed_chromosomes

    def mutation(self, chromosomes:List[Chromosome], agh_data:AGHData) -> List[Chromosome]:
        """Randomly re-draw one gene of some chromosomes, in place.

        The first ``PRESERVED_NUM`` chromosomes are never mutated. Each other
        one is mutated with probability ``mutation_rate`` and has its fitness
        re-evaluated. A draw that hits the shuttle-bus slot of a near-gate
        flight is dropped.
        """
        flight_num = len(agh_data.flight_data)
        vehicle_types = list(VehicleType)
        for i in range(PRESERVED_NUM, min(self.pop_num, len(chromosomes))):
            if random() >= self.mutation_rate:
                continue

            flight_index = randint(0, flight_num - 1)
            flight = agh_data.flight_data[flight_index]
            is_near_gate = agh_data.gate_data[flight.gate_index].is_near_gate

            vehicle_type = randint(0, len(vehicle_types) - 1)
            vehicle_index = randint(0, len(agh_data.vehicle_data[vehicle_type]) - 1)

            if is_near_gate and vehicle_type == VehicleType.ShuttleBus.value:
                continue

            e, l = self.get_arrival_time_border(\
                vehicle_types[vehicle_type], flight.arrival_time_stamp, is_near_gate)
            vehicle_arrival_time = randint(e, l)

            mutant = chromosomes[i]
            mutant.data[flight_index][vehicle_type] = (vehicle_index, vehicle_arrival_time)
            # The cached fitness no longer matches the genes; refresh it.
            mutant.fitness = self.get_fitness(mutant, agh_data)
        return chromosomes

    def translate(self, chromosome:Chromosome, agh_data:AGHData) -> List[Operation]:
        """Decode ``chromosome`` into operations sorted by launch time.

        Returns an empty list when the chromosome is infeasible, i.e. a gate
        is unreachable or a vehicle would have to leave before it is free.
        """
        vehicle_infos = self.generate_vehicle_infos(agh_data.vehicle_data)
        plan = []
        for flight_index, op_infos in enumerate(chromosome.data):
            flight = agh_data.flight_data[flight_index]
            flight_arrival_time = flight.arrival_time_stamp
            gate_index = flight.gate_index
            complete_times = [0] * len(VEHICLE_NAMES)
            is_near_gate = agh_data.gate_data[gate_index].is_near_gate

            for vtype in VehicleType:
                if is_near_gate and vtype == VehicleType.ShuttleBus:
                    continue

                vehicle_index, vehicle_arrival_time = op_infos[vtype.value]
                vehicle_info = vehicle_infos[vtype.value][vehicle_index]

                if vehicle_info.is_source_a_park:
                    dis = agh_data.get_distance_park_to_gate(vehicle_info.park_index, gate_index)
                    source = vehicle_info.park_index
                else:
                    dis = agh_data.get_distance_between_gates(vehicle_info.gate_index, gate_index)
                    source = vehicle_info.gate_index

                if dis == -1:
                    return []

                vehicle_lanuch_time = vehicle_arrival_time - ceil(dis/Vehicle.property_of_type[vtype].speed)
                if vehicle_lanuch_time < vehicle_info.complete_time:
                    return []

                heappush(plan, Operation(vehicle_lanuch_time, vtype, vehicle_index, gate_index, flight_index, \
                    dis, source, vehicle_info.is_source_a_park))

                pre_service_end_time = \
                    self.get_pre_service_end_time(max(flight_arrival_time, vehicle_arrival_time), complete_times, vtype, is_near_gate)

                complete_time = pre_service_end_time + Vehicle.property_of_type[vtype].service_time_cost
                complete_times[vtype.value] = complete_time

                back_dis = agh_data.get_distance_park_to_gate(vehicle_info.park_index, gate_index)
                vehicle_info.consume_capcity(gate_index, complete_time, back_dis)

        # Drain the heap to obtain the operations in launch-time order.
        solution = []
        while plan:
            solution.append(heappop(plan))
        return solution

def get_cost(ops:List[Operation], data:AGHData):
    """Compute the travel cost of a plan.

    Each operation costs its distance times the unit cost of its vehicle type.
    Remaining capacity is tracked per vehicle in the order of ``ops``; the
    operation that uses up a vehicle's capacity is also charged the trip from
    its gate back to the vehicle's park, after which the capacity is restored.

    Returns:
        ``(per-operation costs, their sum, their maximum)``.
    """
    unit_cost_of = {}
    full_capacity_of = {}
    remaining = []
    for vtype in VehicleType:
        spec = Vehicle.property_of_type[vtype]
        unit_cost_of[vtype] = spec.cost
        full_capacity_of[vtype] = spec.capcity
        remaining.append([spec.capcity for _ in data.vehicle_data[vtype.value]])

    per_op_costs = []
    total = 0
    peak = 0
    for op in ops:
        vtype = op.vehicle_type
        slots = remaining[vtype.value]
        vindex = op.vehicle_index

        op_cost = op.distance * unit_cost_of[vtype]
        slots[vindex] -= 1
        if slots[vindex] == 0:
            # Capacity used up: charge the trip back to the park.
            park_index = data.vehicle_data[vtype.value][vindex].park_index
            back_dis = data.get_distance_park_to_gate(park_index, op.gate_index)
            op_cost += back_dis * unit_cost_of[vtype]
            slots[vindex] = full_capacity_of[vtype]

        per_op_costs.append(op_cost)
        total += op_cost
        peak = max(peak, op_cost)
    return per_op_costs, total, peak
